import math

import torch
from torch.distributions import Normal

from src.envs.base_environment import ContinuousEnvironment


class LineEnvironment(ContinuousEnvironment):
    """
    ### Description

    A bounded 1D environment where the state is a scalar value and the action is a scalar
    value (a delta) that is added to the state.

    ### Action Space

    | Num | Action | Min | Max                       |
    |-----|--------|-----|---------------------------|
    | 0   | Delta  | 0   | upper_bound-lower_bound   |

    Note that the methods of this class do not clip actions: `postprocess_params` only
    constrains the *means* of the policy components to
    [-(upper_bound - lower_bound) / 2, (upper_bound - lower_bound) / 2].

    ### Observation Space

    | Num | Observation | Min         | Max         |
    |-----|-------------|-------------|-------------|
    | 0   | State       | lower_bound | upper_bound |

    `step` and `backward_step` additionally use column 1 of the state tensor as a step counter.

    ### Rewards

    The log-reward is the log of the (equally weighted, unnormalised) sum of Gaussian
    densities with means `mus` and variances `variances`:

        log r(x) = logsumexp([Normal(mu, sigma).log_prob(x) for mu, sigma in zip(mus, sigmas)])

    with `sigmas = sqrt(variances)`.

    ### Policy Parameterisation

    The policy is a Gaussian mixture model with `mixture_dim` components. The policy network
    outputs `3 * mixture_dim` values per sample: component means, component standard
    deviations and mixture logits (see `postprocess_params`).

    ### Arguments (read from `config["env"]`)

    - mus: List of means of the Gaussians in the mixture model of the reward.
    - variances: List of variances of the Gaussians in the mixture model of the reward.
    - max_policy_std: Maximum standard deviation of the policy.
    - min_policy_std: Minimum standard deviation of the policy.
    - mixture_dim: Number of components in the Gaussian mixture model in the parameterisation of the policy.
    - lower_bound: Lower bound of the state space.
    - upper_bound: Upper bound of the state space.

    `config["device"]` is used as the device for the reward-parameter tensors.
    """

    def __init__(self, config):
        """Build the environment from ``config``.

        Args:
            config: Mapping with an ``"env"`` sub-dict holding the parameters listed in the
                class docstring, and a ``"device"`` entry for the reward tensors.

        Raises:
            ValueError: if a required key is missing from ``config["env"]``.
        """
        self._init_required_params(config)
        # One Normal per reward component; they are combined in ``log_reward``.
        self.mixture = [Normal(m, s) for m, s in zip(self.mus, self.sigmas)]
        mixture_dim = config["env"]["mixture_dim"]
        super().__init__(config,
                         dim=1,
                         feature_dim=1,
                         angle_dim=[False],
                         action_dim=1,
                         lower_bound=config["env"]["lower_bound"],
                         upper_bound=config["env"]["upper_bound"],
                         mixture_dim=mixture_dim,
                         # Per component: one mean, one std and one mixture logit.
                         output_dim=mixture_dim * 3,)

    def _init_required_params(self, config):
        """Validate ``config["env"]`` and store the reward / policy hyperparameters.

        Sets ``self.mus``, ``self.sigmas``, ``self.variances`` (tensors on
        ``config["device"]``) and ``self.max_policy_std`` / ``self.min_policy_std``.

        Raises:
            ValueError: if any required key is missing from ``config["env"]``; the message
                names exactly the missing keys.
        """
        env_config = config["env"]
        # ``mixture_dim`` is not stored here but is read unconditionally by ``__init__``,
        # so it is validated together with the rest.
        required_params = ["mus", "variances", "max_policy_std", "min_policy_std",
                           "lower_bound", "upper_bound", "mixture_dim"]
        missing = [param for param in required_params if param not in env_config]
        if missing:
            # Explicit raise rather than ``assert`` so the check survives ``python -O``.
            raise ValueError(f"Missing required parameters in config['env']: {missing}")
        device = config["device"]
        self.mus = torch.tensor(env_config["mus"], device=device)
        # sqrt is taken in Python before building the tensor, so sigmas is floating point
        # even if the variances are given as integers.
        self.sigmas = torch.tensor([math.sqrt(v) for v in env_config["variances"]], device=device)
        self.variances = torch.tensor(env_config["variances"], device=device)
        self.max_policy_std = env_config["max_policy_std"]
        self.min_policy_std = env_config["min_policy_std"]

    def log_reward(self, x):
        """Return the log-reward of ``x``: logsumexp over the reward components' log-densities.

        Each component has scalar parameters, so ``log_prob(x)`` has the shape of ``x``;
        the results are stacked on a new leading dim and reduced over it, giving an output
        with the same shape as ``x``.
        """
        return torch.logsumexp(torch.stack([m.log_prob(x) for m in self.mixture], 0), 0)

    def step(self, x, action):
        """Apply a forward transition: add ``action`` to column 0 and increment column 1.

        Returns a new tensor; ``x`` is not modified. Any columns of ``x`` beyond the first
        two are zero in the result (it starts from ``zeros_like(x)``).
        """
        # squeeze() drops size-1 dims, e.g. an action of shape (batch, 1) becomes (batch,).
        action = action.squeeze()
        new_x = torch.zeros_like(x)
        new_x[:, 0] = x[:, 0] + action  # Add action delta.
        new_x[:, 1] = x[:, 1] + 1  # Increment step counter.

        return new_x

    def backward_step(self, x, action):
        """Invert ``step``: subtract ``action`` from column 0 and decrement column 1.

        Returns a new tensor; ``x`` is not modified. Any columns of ``x`` beyond the first
        two are zero in the result.
        """
        action = action.squeeze()
        new_x = torch.zeros_like(x)
        new_x[:, 0] = x[:, 0] - action  # Undo action delta.
        new_x[:, 1] = x[:, 1] - 1  # Decrement step counter.

        return new_x

    def compute_initial_action(self, first_state):
        """Return the action that moves ``self.init_value`` to ``first_state``.

        NOTE(review): ``self.init_value`` is not set in this class; presumably it is
        provided by ``ContinuousEnvironment`` — confirm.
        """
        return (first_state - self.init_value)

    def _init_policy_dist(self, param_dict):
        """Build the Gaussian-mixture policy distribution from ``postprocess_params`` output.

        ``param_dict`` holds ``"mus"``, ``"stds"`` and ``"weights"`` (mixture probabilities),
        presumably each of shape ``(batch, mixture_dim)`` — TODO confirm with caller.
        Each component is a Normal with event shape ``(1,)`` (via ``unsqueeze(-1)`` and
        ``Independent(..., 1)``), so samples carry a trailing dimension of size 1.
        """
        mus, stds, weights = param_dict["mus"], param_dict["stds"], param_dict["weights"]
        # Positional first argument of Categorical is ``probs``; weights are already softmaxed.
        mix = torch.distributions.Categorical(weights)
        comp = torch.distributions.Independent(torch.distributions.Normal(mus.unsqueeze(-1), stds.unsqueeze(-1)), 1)

        return torch.distributions.MixtureSameFamily(mix, comp)

    def postprocess_params(self, params):
        """Turn raw policy-network outputs into Gaussian-mixture parameters.

        Args:
            params: 2-D tensor whose columns are laid out as
                ``[mean logits | std logits | mixture logits]``; the first two groups have
                ``self.mixture_dim`` columns each and the remaining columns are mixture
                logits (presumably ``3 * mixture_dim`` columns in total — TODO confirm).

        Returns:
            Dict with keys ``"mus"``, ``"stds"`` and ``"weights"``, as consumed by
            ``_init_policy_dist`` and ``add_noise``.
        """
        k = self.mixture_dim
        mu_params = params[:, :k]
        std_params = params[:, k:2 * k]
        weight_params = params[:, 2 * k:]
        # Component means are squashed into [-(ub - lb) / 2, (ub - lb) / 2].
        # NOTE(review): this alone does not guarantee that state + mean stays inside
        # [lower_bound, upper_bound]; confirm the intended range.
        mus = torch.sigmoid(mu_params) * (self.upper_bound - self.lower_bound) + (self.lower_bound - self.upper_bound) / 2
        # Component stds are squashed into (min_policy_std, max_policy_std).
        stds = torch.sigmoid(std_params) * (self.max_policy_std - self.min_policy_std) + self.min_policy_std
        # Normalise the mixture logits into probabilities.
        weights = torch.softmax(weight_params, dim=1)
        return {"mus": mus, "stds": stds, "weights": weights}

    def add_noise(self, param_dict, off_policy_noise):
        """Widen the policy components by adding ``off_policy_noise`` to their stds.

        The ``"stds"`` entry of ``param_dict`` is replaced by a new tensor and the same
        dict is returned. The original stds tensor is not modified in place, so other
        holders of it (including autograd, if it was already used in the graph) still see
        the original values.
        """
        param_dict["stds"] = param_dict["stds"] + off_policy_noise

        return param_dict

